# archived/rl_gamerserver_ray/common/sagemaker_rl/ray_launcher.py
import json
import os
import subprocess
import sys
import time
from enum import Enum
from shutil import copyfile
import boto3
import ray
from ray.tune import run_experiments
from .configuration_list import ConfigurationList
from .docker_utils import get_ip_from_host
from .sage_cluster_communicator import SageClusterCommunicator
from .tf_serving_utils import change_permissions_recursive, export_tf_serving, natural_keys
TERMINATION_SIGNAL = "JOB_TERMINATED"
INTERMEDIATE_DIR = "/opt/ml/output/intermediate"
CHECKPOINT_DIR = "/opt/ml/input/data/checkpoint"
MODEL_OUTPUT_DIR = "/opt/ml/model"
class Cluster(Enum):
"""
Used when training is done in heterogeneous mode, i.e. 2 SageMaker jobs are launched with
different instance types. Usually, primary cluster has a single GPU instance responsible
for Neural Network training and secondary cluster has CPU instances for rollouts.
For single machine or homogeneous cluster, primary is the default type.
"""
Primary = "primary"
Secondary = "secondary"
class SageMakerRayLauncher(object):
"""Base class for SageMaker RL applications using Ray-RLLib.
Customers should sub-class this, fill in the required methods, and
call .train_main() to start a training process.
Example::
def create_environment(env_config):
# Import must happen inside the method so workers re-import
import roboschool
return gym.make('RoboschoolHumanoid-v1')
class MyLauncher(SageMakerRayLauncher):
def register_env_creator(self):
register_env("RoboschoolHumanoid-v1", create_environment)
def get_experiment_config(self):
return {
"training": {
"env": "RoboschoolHumanoid-v1",
"run": "PPO",
...
}
}
if __name__ == "__main__":
MyLauncher().train_main()
"""
def __init__(self):
self.num_cpus = int(os.environ.get("SM_NUM_CPUS", 1))
self.num_gpus = int(os.environ.get("SM_NUM_GPUS", 0))
self.cluster_type = self._get_cluster_type()
self.num_instances_secondary_cluster = int(
os.environ.get("SM_HP_RL_NUM_INSTANCES_SECONDARY", 0)
)
self.host_name = os.environ.get("SM_CURRENT_HOST", "algo-1")
self.hosts_info = json.loads(os.environ.get("SM_RESOURCE_CONFIG"))["hosts"]
self.is_master_node = (
self.hosts_info[0] == self.host_name and self.cluster_type == Cluster.Primary
)
self.sage_cluster_communicator = SageClusterCommunicator()
def _get_cluster_type(self):
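        """Read the cluster type from the SM_HP_RL_CLUSTER_TYPE hyperparameter (defaults to primary)."""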
cluster_str = os.environ.get("SM_HP_RL_CLUSTER_TYPE", "primary")
if cluster_str.lower() == "primary":
return Cluster.Primary
else:
return Cluster.Secondary
def register_env_creator(self):
"""Sub-classes must implement this."""
raise NotImplementedError(
"Subclasses should implement this to call ray.tune.registry.register_env"
)
def get_experiment_config(self):
raise NotImplementedError(
"Subclasses must define the experiment config to pass to ray.tune.run_experiments"
)
def customize_experiment_config(self, config):
"""Applies command-line hyperparameters to the config."""
# TODO: use ConfigList from Coach launcher, and share customization code.
hyperparams_dict = json.loads(os.environ.get("SM_HPS", "{}"))
# Set output dir to intermediate
# TODO: move this to before customer-specified so they can override
hyperparams_dict["rl.training.local_dir"] = INTERMEDIATE_DIR
hyperparams_dict["rl.training.checkpoint_at_end"] = True
hyperparams_dict["rl.training.checkpoint_freq"] = config["training"].get(
"checkpoint_freq", 10
)
self.hyperparameters = ConfigurationList() # TODO: move to shared
for name, value in hyperparams_dict.items():
# self.map_hyperparameter(name, val) #TODO
if name.startswith("rl."):
# self.apply_hyperparameter(name, value) #TODO
self.hyperparameters.store(name, value)
# else:
# raise ValueError("Unknown hyperparameter %s" % name)
self.hyperparameters.apply_subset(config, "rl.")
return config
def get_all_host_names(self):
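        """Return the prefixed host names of every instance expected to join the Ray cluster.

        Hosts listed in this job's resource config are prefixed with this node's cluster type;
        secondary-cluster hosts are assumed to be named algo-1..algo-N and are prefixed with
        "secondary".
        """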
all_workers_host_names = []
for host in self.hosts_info:
# All primary cluster instances' hostnames. Prefix with "primary"
all_workers_host_names.append("%s:%s" % (self.cluster_type.value, host))
for i in range(self.num_instances_secondary_cluster):
# All secondary cluster instances' hostnames. Prefix with "secondary"
all_workers_host_names.append("%s:algo-%s" % (Cluster.Secondary.value, i + 1))
return all_workers_host_names
def ray_init_config(self):
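        """Build the keyword arguments for ray.init() on this node.

        On a single-machine job this is just the CPU/GPU counts. On the master node of a
        distributed job it also starts the Ray head, publishes its address via S3, and waits
        for all workers to join; on worker nodes it joins the cluster and then blocks until
        the termination signal arrives.
        """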
        # Request at least 3 CPUs from Ray, even on smaller instances.
        num_cpus = max(self.num_cpus, 3)
        config = {"num_cpus": num_cpus, "num_gpus": self.num_gpus}
if self.is_master_node:
all_workers_host_names = self.get_all_host_names()[1:]
# Single machine job
if len(all_workers_host_names) == 0:
return config
master_ip = get_ip_from_host(host_name=self.host_name)
self.start_ray_cluster(master_ip)
self.sage_cluster_communicator.write_host_config(
ip=master_ip, host_name="%s:%s" % (self.cluster_type.value, self.host_name)
)
self.sage_cluster_communicator.create_s3_signal(
"%s:%s" % (self.cluster_type.value, self.host_name)
)
print("Waiting for %s worker nodes to join!" % (len(all_workers_host_names)))
self.sage_cluster_communicator.wait_for_signals(all_workers_host_names)
print("All worker nodes have joined the cluster. Now training...")
if ray.__version__ >= "0.8.2":
config = {"address": "%s:6379" % master_ip}
else:
config = {"redis_address": "%s:6379" % master_ip}
else:
master_ip, master_hostname = self.sage_cluster_communicator.get_master_config()
node_ip = get_ip_from_host(host_name=self.host_name)
self.sage_cluster_communicator.wait_for_signals([master_hostname])
print("Attempting to join ray cluster.")
self.join_ray_cluster(master_ip, node_ip)
self.sage_cluster_communicator.create_s3_signal(
"%s:%s" % (self.cluster_type.value, self.host_name)
)
print("Joined ray cluster at %s successfully!" % master_ip)
self.sage_cluster_communicator.wait_for_signals(
[TERMINATION_SIGNAL], timeout=sys.maxsize
)
print("Received job termination signal. Shutting down.")
return config
def start_ray_cluster(self, master_ip):
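        """Start the Ray head node on this instance, using the CLI flags that match the installed Ray version."""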
if ray.__version__ >= "1.0.0": # port starts to be used after ray 1.0.0
p = subprocess.Popen(
"ray start --head --port=6379 --node-ip-address=%s" % master_ip,
shell=True,
stderr=subprocess.STDOUT,
)
elif ray.__version__ >= "0.6.5":
p = subprocess.Popen(
"ray start --head --redis-port=6379 --node-ip-address=%s" % master_ip,
shell=True,
stderr=subprocess.STDOUT,
)
else:
p = subprocess.Popen(
"ray start --head --redis-port=6379 --no-ui --node-ip-address=%s" % master_ip,
shell=True,
stderr=subprocess.STDOUT,
)
        time.sleep(3)
        # poll() returns None while `ray start` is still running; only a non-zero
        # exit code indicates that the head node failed to start.
        if p.poll() not in (None, 0):
            raise RuntimeError("Could not start Ray server.")
def join_ray_cluster(self, master_ip, node_ip):
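        """Connect this instance to the Ray head node running on the master."""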
if ray.__version__ >= "0.8.2":
p = subprocess.Popen(
"ray start --address=%s:6379" % (master_ip),
shell=True,
stderr=subprocess.STDOUT,
stdout=subprocess.PIPE,
)
else:
p = subprocess.Popen(
"ray start --redis-address=%s:6379 --node-ip-address=%s" % (master_ip, node_ip),
shell=True,
stderr=subprocess.STDOUT,
)
        time.sleep(3)
        # poll() returns None while `ray start` is still running; only a non-zero
        # exit code indicates that joining the cluster failed.
        if p.poll() not in (None, 0):
            raise RuntimeError("Could not join Ray server running at %s:6379" % master_ip)
def copy_checkpoints_to_model_output(self):
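        """Wait for Ray Tune to write checkpoints to the intermediate dir, then copy the
        latest checkpoint files (data plus metadata) to the model output dir.
        """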
checkpoints = []
count = 0
while not checkpoints:
count += 1
for root, directories, filenames in os.walk(INTERMEDIATE_DIR):
for filename in filenames:
if filename.startswith("checkpoint"):
checkpoints.append(os.path.join(root, filename))
time.sleep(5)
if count >= 6:
raise RuntimeError("Failed to find checkpoint files")
checkpoints.sort(key=natural_keys)
latest_checkpoints = checkpoints[-2:]
validation = sum(
1 if x.endswith("tune_metadata") or x.endswith("extra_data") else 0
for x in latest_checkpoints
)
if ray.__version__ >= "0.6.5":
            if validation != 1:
raise RuntimeError("Failed to save checkpoint files - .tune_metadata")
else:
            if validation != 2:
raise RuntimeError(
"Failed to save checkpoint files - .tune_metadata or .extra_data"
)
for source_path in latest_checkpoints:
_, ext = os.path.splitext(source_path)
destination_path = os.path.join(MODEL_OUTPUT_DIR, "checkpoint%s" % ext)
copyfile(source_path, destination_path)
print("Saved the checkpoint file %s as %s" % (source_path, destination_path))
    def save_experiment_config(self):
        """Copy the experiment's params.json from the intermediate dir to the model output dir."""
        source = None
        for root, directories, filenames in os.walk(INTERMEDIATE_DIR):
            if "params.json" in filenames:
                source = os.path.join(root, "params.json")
                break
        if source is None:
            raise RuntimeError("Could not find params.json under %s" % INTERMEDIATE_DIR)
        copyfile(source, os.path.join(MODEL_OUTPUT_DIR, "params.json"))
        print("Saved model configuration.")
def create_tf_serving_model(self, algorithm=None, env_string=None):
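        """Rebuild the trained agent from the saved params.json and checkpoint, then export
        it as a TensorFlow Serving model into the model output dir.
        """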
self.register_env_creator()
if ray.__version__ >= "0.6.5":
from ray.rllib.agents.registry import get_agent_class
else:
from ray.rllib.agents.agent import get_agent_class
cls = get_agent_class(algorithm)
with open(os.path.join(MODEL_OUTPUT_DIR, "params.json")) as config_json:
config = json.load(config_json)
print("Loaded config for TensorFlow serving.")
config["monitor"] = False
config["num_workers"] = 1
config["num_gpus"] = 0
agent = cls(env=env_string, config=config)
checkpoint = os.path.join(MODEL_OUTPUT_DIR, "checkpoint")
agent.restore(checkpoint)
export_tf_serving(agent, MODEL_OUTPUT_DIR)
def save_checkpoint_and_serving_model(self, algorithm=None, env_string=None, use_pytorch=False):
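        """Save the experiment config and the latest checkpoint, and export a TF Serving
        model unless training used PyTorch.
        """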
self.save_experiment_config()
self.copy_checkpoints_to_model_output()
if use_pytorch:
print("Skipped PyTorch serving.")
else:
self.create_tf_serving_model(algorithm, env_string)
# To ensure SageMaker local mode works fine
change_permissions_recursive(INTERMEDIATE_DIR, 0o777)
change_permissions_recursive(MODEL_OUTPUT_DIR, 0o777)
def set_up_checkpoint(self, config=None):
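        """Point config["training"]["restore"] at a checkpoint from the SageMaker checkpoint
        channel, unless the user already supplied a restore path or no checkpoint exists.
        """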
try:
checkpoint_dir = config["training"]["restore"]
print("Found checkpoint dir %s in user config." % checkpoint_dir)
return config
except KeyError:
pass
if not os.path.exists(CHECKPOINT_DIR):
print("No checkpoint path specified. Training from scratch.")
return config
checkpoint_dir = self._checkpoint_dir_finder(CHECKPOINT_DIR)
# validate the contents
print("checkpoint_dir is {}".format(checkpoint_dir))
checkpoint_dir_contents = os.listdir(checkpoint_dir)
if len(checkpoint_dir_contents) not in [2, 3]:
raise RuntimeError(
f"Unexpected files {checkpoint_dir_contents} in checkpoint dir. "
"Please check ray documents for the correct checkpoint format."
)
validation = 0
checkpoint_file_in_container = ""
for filename in checkpoint_dir_contents:
is_tune_metadata = filename.endswith("tune_metadata")
is_extra_data = filename.endswith("extra_data")
is_checkpoint_meta = is_tune_metadata + is_extra_data
validation += is_checkpoint_meta
if not is_checkpoint_meta:
checkpoint_file_in_container = os.path.join(checkpoint_dir, filename)
if ray.__version__ >= "0.6.5":
            if validation != 1:
raise RuntimeError("Failed to find .tune_metadata to restore checkpoint.")
else:
            if validation != 2:
raise RuntimeError(
"Failed to find .tune_metadata or .extra_data to restore checkpoint"
)
if checkpoint_file_in_container:
print(
"Found checkpoint: %s. Setting `restore` path in ray config."
% checkpoint_file_in_container
)
config["training"]["restore"] = checkpoint_file_in_container
else:
print("No valid checkpoint found in %s. Training from scratch." % checkpoint_dir)
return config
def _checkpoint_dir_finder(self, current_dir=None):
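        """Walk down through single-child subdirectories until reaching the directory that
        actually contains the checkpoint files.
        """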
current_dir_subfolders = os.walk(current_dir).__next__()[1]
if len(current_dir_subfolders) > 1:
raise RuntimeError(
f"Multiple folders detected: '{current_dir_subfolders}'."
"Please provide one checkpoint only."
)
elif not current_dir_subfolders:
return current_dir
return self._checkpoint_dir_finder(os.path.join(current_dir, *current_dir_subfolders))
def launch(self):
"""Actual entry point into the class instance where everything happens.
Lots of delegating to classes that are in subclass or can be over-ridden.
"""
self.register_env_creator()
# All worker nodes will block at this step during training
ray_cluster_config = self.ray_init_config()
if not self.is_master_node:
return
# Start the driver on master node
ray.init(**ray_cluster_config)
experiment_config = self.get_experiment_config()
experiment_config = self.customize_experiment_config(experiment_config)
experiment_config = self.set_up_checkpoint(experiment_config)
print(
'Important! Ray with version <=0.7.2 may report "Did not find checkpoint file" even if the',
"experiment is actually restored successfully. If restoration is expected, please check",
'"training_iteration" in the experiment info to confirm.',
)
        # Up to and including Ray 0.8.5, "use_pytorch" selects the framework; after 0.8.5, "framework" is used instead.
if experiment_config["training"]["config"].get(
"framework", False
): # if "framework" is used
            use_pytorch = experiment_config["training"]["config"]["framework"] == "torch"
if ray.__version__ <= "0.8.5":
experiment_config["training"]["config"]["use_pytorch"] = use_pytorch
del experiment_config["training"]["config"]["framework"]
else: # if "use_pytorch" is used or no framework specified
use_pytorch = experiment_config["training"]["config"].get("use_pytorch", False)
if ray.__version__ > "0.8.5":
experiment_config["training"]["config"]["framework"] = (
"torch" if use_pytorch else "tf"
)
experiment_config["training"]["config"].pop("use_pytorch", None)
run_experiments(experiment_config)
all_workers_host_names = self.get_all_host_names()[1:]
# If distributed job, send TERMINATION_SIGNAL to all workers.
if len(all_workers_host_names) > 0:
self.sage_cluster_communicator.create_s3_signal(TERMINATION_SIGNAL)
algo = experiment_config["training"]["run"]
env_string = experiment_config["training"]["config"]["env"]
self.save_checkpoint_and_serving_model(
algorithm=algo, env_string=env_string, use_pytorch=use_pytorch
)
@classmethod
def train_main(cls):
"""main function that kicks things off"""
launcher = cls()
launcher.launch()